{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cadd7f7b",
   "metadata": {},
   "source": [
    "## 8.6 DenseNet概述"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "909daaec",
   "metadata": {},
   "source": [
    "论文可见于：[论文链接](https://arxiv.org/abs/1608.06993)\n",
    "\n",
    "DenseNet（Densely connected convolutional networks）是由Gao Huang等人在2017年提出的一种深度卷积神经网络，它在ResNet的基础上进行了改进，其中使用了联结不同层的稠密块（dense blocks）来构建模型。DenseNet的一个主要优势在于，它可以有效地减少模型的参数数量，同时保持或者提高模型的性能。这是通过在每个层之间使用\"短接（shortcut connections）\"来实现的，这些短接使得每个层都可以直接访问之前所有层的特征。\n",
    "\n",
    "DenseNet的另一大特色是通过特征在channel上的连接来实现特征重用（feature reuse）。这些特点让DenseNet在参数和计算成本更少的情形下实现比ResNet更优的性能，DenseNet也因此斩获CVPR 2017的最佳论文奖。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15fa7389",
   "metadata": {},
   "source": [
    "### 8.6.1 DenseNet基本思想"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f86c6667",
   "metadata": {},
   "source": [
    "DenseNet核心思想在于建立了不同层之间的连接关系，充分利用了feature，进一步减轻了梯度消失问题，加深网络不是问题，而且训练效果非常好。另外通过利用bottleneck layer，Translation layer以及较小的growth rate使得网络变窄，参数减少，有效抑制了过拟合，同时计算量也减少了，在和ResNet的对比中优势还是比较明显。\n",
    "\n",
    "相比ResNet，DenseNet提出了一个更激进的密集连接机制。即互相连接所有的层，具体来说就是每个层都会接受其前面所有层作为其额外的输入。其dense block结构如下图所示。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd3df9cf",
   "metadata": {},
   "source": [
    "<img src=\"./images/8-6-1.png\" width=\"40%\"></img>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b3f6a6c",
   "metadata": {},
   "source": [
    "将一个个dense block结构拼接之后，DenseNet的密集连接机制如下图所示。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c627fe63",
   "metadata": {},
   "source": [
    "<img src=\"./images/8-6-2.png\" width=\"100%\"></img>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "975cd910",
   "metadata": {},
   "source": [
    "相比之下，ResNet是每个层与前面的某层短路连接在一起，连接方式是通过元素级相加。而在DenseNet中，每个层都会与前面所有层在channel维度上连接（concat）在一起，并作为下一层的输入。也就是说，对于一个 $L$ 层的网络，DenseNet共包含 $\\frac{L(L+1)}{2}$ 个连接，而不是传统意义上的L层连接。因为密集连接的特性，称其为DenseNet。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "573f54e8",
   "metadata": {},
   "source": [
    "### 8.6.2 DenseNet结构"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05e963ed",
   "metadata": {},
   "source": [
    "以DenseNet-121为例，其结构示意图如下图所示："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "608fd72d",
   "metadata": {},
   "source": [
    "<img src=\"./images/8-6-3.png\" width=\"80%\"></img>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8212a619",
   "metadata": {},
   "source": [
    "不同深度的DenseNet结构如下图所示："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9783afe9",
   "metadata": {},
   "source": [
    "<img src=\"./images/8-6-4.png\" width=\"80%\"></img>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1fa08516",
   "metadata": {},
   "source": [
    "看上面的结构图好像也很复杂，但实际上对照上表就是几部分：\n",
    "\n",
    "1. 第一部分卷积层+最大池化层；\n",
    "2. 第二部分是4组不同数量和规格的Dense Block；\n",
    "3. 第三部分是3个Transition Layer，主要作用就是降维和调整特征图size；\n",
    "3. 第四部分平均池化层+全连接层。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b152170",
   "metadata": {},
   "source": [
    "### 8.6.3 DenseNet代码实现"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12e696de",
   "metadata": {},
   "source": [
    "接下来看一下如何用代码实现相关网络结构。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d27ca2df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入必要的库\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32fc4b69",
   "metadata": {},
   "source": [
    "首先是DenseBlock中的内部结构，这里是BN+ReLU+Conv1x1+BN+ReLU+Conv3x3的结构。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "852700ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "class _DenseLayer(nn.Module):\n",
    "    def __init__(self, num_input_features, growth_rate, bn_size):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Sequential(\n",
    "            nn.BatchNorm2d(num_input_features),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)\n",
    "        )\n",
    "        self.conv2 = nn.Sequential(\n",
    "            nn.BatchNorm2d(bn_size * growth_rate),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.conv1(x)\n",
    "        out = self.conv2(out)\n",
    "        return torch.cat([x, out], 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cc178a4",
   "metadata": {},
   "source": [
    "然后实现DenseBlock模块，内部是密集连接方式（输入特征数线性增长）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c7089c31",
   "metadata": {},
   "outputs": [],
   "source": [
    "class _DenseBlock(nn.Module):\n",
    "    def __init__(self, num_layers, num_input_features, bn_size, growth_rate):\n",
    "        super().__init__()\n",
    "        \n",
    "        layers = []\n",
    "        for i in range(num_layers):\n",
    "            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size)\n",
    "            layers.append(layer)\n",
    "        self.block = nn.Sequential(*layers)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        return self.block(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8840bc8a",
   "metadata": {},
   "source": [
    "接下来实现Transition层，它是BN+ReLU+Conv1x1+Pool的结构。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "789ca9b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "class _Transition(nn.Module):\n",
    "    def __init__(self, num_input_features, num_output_features):\n",
    "        super().__init__()\n",
    "        self.trans = nn.Sequential(\n",
    "            nn.BatchNorm2d(num_input_features),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False),\n",
    "            nn.AvgPool2d(kernel_size=2, stride=2)\n",
    "        )\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.trans(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "300fabed",
   "metadata": {},
   "source": [
    "然后组装DenseNet的网络结构。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4b0b4118",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DenseNet(nn.Module):\n",
    "    def __init__(self, block_config, growth_rate=32, num_init_features=64, bn_size=4, num_classes=1000):\n",
    "\n",
    "        super().__init__()\n",
    "\n",
    "        # First convolution\n",
    "        self.features = nn.Sequential(\n",
    "            nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False),\n",
    "            nn.BatchNorm2d(num_init_features),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
    "        )\n",
    "\n",
    "        # Each denseblock\n",
    "        num_features = num_init_features\n",
    "        layers = []\n",
    "        for i, num_layers in enumerate(block_config):\n",
    "            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, bn_size=bn_size, growth_rate=growth_rate)\n",
    "            layers.append(block)\n",
    "            num_features = num_features + num_layers * growth_rate\n",
    "            if i != len(block_config) - 1:\n",
    "                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)\n",
    "                layers.append(trans)\n",
    "                num_features = num_features // 2\n",
    "        layers.append(nn.BatchNorm2d(num_features))\n",
    "        self.denseblock = nn.Sequential(*layers)\n",
    "\n",
    "        # Linear layer\n",
    "        self.classifier = nn.Linear(num_features, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        features = self.features(x)\n",
    "        features = self.denseblock(features)\n",
    "        out = F.relu(features, inplace=True)\n",
    "        out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1)\n",
    "        out = self.classifier(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50614d99",
   "metadata": {},
   "source": [
    "最后封装不同深度模型的函数，直接调用即可。block_config参数的组合与上文中表格是一致的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "45faa4eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 封装函数\n",
    "def densenet121():\n",
    "    return DenseNet(block_config=(6, 12, 24, 16), growth_rate=32, num_init_features=64)\n",
    "\n",
    "def densenet161():\n",
    "    return DenseNet(block_config=(6, 12, 36, 24), growth_rate=48, num_init_features=96)\n",
    "\n",
    "def densenet169():\n",
    "    return DenseNet(block_config=(6, 12, 32, 32), growth_rate=32, num_init_features=64)\n",
    "\n",
    "def densenet201():\n",
    "    return DenseNet(block_config=(6, 12, 48, 32), growth_rate=32, num_init_features=64)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8429db2",
   "metadata": {},
   "source": [
    "### 8.6.4 小结"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "208bbc66",
   "metadata": {},
   "source": [
    "* DenseNet 和 ResNet 非常的相似，简单来看，可以理解为 ResNet 的特征图是按位置相加，而DenseNet的特征图是进行concat拼接。但实际上基于密集连接的结构，能够形成更复杂的特征模式。\n",
    "* DenseNet 可以有效地减少模型的参数数量，保持或者提高模型的性能，同时通过特征在channel上的连接来实现特征重用。\n",
    "* DenseNet 结构可以有效环节梯度消失和模型退化的问题。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c3fb22b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
